[Model] feat: FastGen DMD2-distilled Wan 2.1 pipelines (T2V, I2V) #2749
Signed-off-by: ayushag <ayushag@nvidia.com>
PR #2749 - [WIP] feat: fastgen model integration OVERALL: WIP (not blocking) Ready for full review when WIP removed. Preliminary scan: fastgen model integration for Wan2.2. 7 files, 375+ LOC. pre-commit failing (gate blocker). PR body empty. Fix pre-commit and add description before requesting review. |
|
BLOCKER scan: This PR is marked as [WIP] and has failing pre-commit checks. Please:
OVERALL: WIP WITH FAILING CHECKS VERDICT: REQUEST_CHANGES |
lishunyang12
left a comment
Early review (WIP) -- fastgen DMD2 model integration
Thanks for the PR. The overall approach of subclassing the existing Wan 2.2 T2V / I2V pipelines for DMD2-distilled 4-step models is reasonable. Here are observations for the current state:
1. Heavy code duplication between T2V and I2V DMD2 classes
WanT2VDMD2Pipeline and WanI2VDMD2Pipeline share identical implementations for:
- __init__ (scheduler replacement)
- _verify_dmd2_request (entire method, ~40 lines)
- forward timestep-patching logic (monkey-patch + try/finally)
- Class constants (GUIDANCE_SCALE, NUM_INFERENCE_STEPS, DMD2_TIMESTEPS)
Please extract this into a mixin (e.g. DMD2PipelineMixin) that both classes inherit from. This avoids the maintenance burden of keeping two copies in sync and makes it trivial to add future DMD2 variants (e.g. VACE).
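A minimal sketch of the mixin idea, using stand-in parent classes rather than the real Wan pipelines. The base classes, the forward signature, and the GUIDANCE_SCALE / DMD2_TIMESTEPS values are placeholders for illustration; only the class names and the 4-step count come from this PR.

```python
class DMD2PipelineMixin:
    """Hypothetical mixin holding the logic the two DMD2 classes duplicate."""

    GUIDANCE_SCALE = 1.0                    # placeholder value
    NUM_INFERENCE_STEPS = 4                 # 4-step distilled model
    DMD2_TIMESTEPS = (999, 750, 500, 250)   # placeholder values

    def forward(self, prompt, num_inference_steps=None, guidance_scale=None):
        # Caller overrides are ignored: the distilled model only supports
        # its fixed step count and guidance setting.
        return super().forward(
            prompt,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            guidance_scale=self.GUIDANCE_SCALE,
        )


class BaseT2VPipeline:
    """Stand-in for the real T2V parent pipeline."""

    def forward(self, prompt, num_inference_steps=50, guidance_scale=5.0):
        return {"kind": "t2v", "prompt": prompt,
                "steps": num_inference_steps, "cfg": guidance_scale}


class BaseI2VPipeline:
    """Stand-in for the real I2V parent pipeline."""

    def forward(self, prompt, num_inference_steps=50, guidance_scale=5.0):
        return {"kind": "i2v", "prompt": prompt,
                "steps": num_inference_steps, "cfg": guidance_scale}


# Each concrete class is now a short stub; the shared logic lives in one place.
class WanT2VDMD2Pipeline(DMD2PipelineMixin, BaseT2VPipeline):
    pass


class WanI2VDMD2Pipeline(DMD2PipelineMixin, BaseI2VPipeline):
    pass
```

Listing the mixin first in the bases puts it ahead of the parent pipeline in the MRO, so its forward runs first and delegates to the parent through super().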
2. Monkey-patching scheduler.set_timesteps is fragile
In forward(), you replace self.scheduler.set_timesteps with a closure and restore it in a finally block. This approach has problems:
- Not thread-safe / async-safe: if two requests are in flight concurrently on the same pipeline instance, they will race on the scheduler's method.
- Fragile against parent refactors: if the parent ever calls set_timesteps differently (e.g. via a local reference), the patch silently breaks.
Consider instead:
- Overriding the parent's denoising loop method directly, or
- Calling self.scheduler.set_timesteps(timesteps=self.DMD2_TIMESTEPS, device=...) once after the parent calls it (by wrapping only the relevant section), or
- Using FlowMatchEulerDiscreteScheduler's built-in timesteps parameter at set_timesteps time and overriding num_inference_steps in the super call.
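One way to avoid patching a method on a shared instance is to put the fixed schedule in a scheduler subclass. The sketch below uses a stand-in base class (BaseScheduler is hypothetical and only mimics a set_timesteps method); the timestep values are placeholders, and the real scheduler's signature may differ.

```python
class BaseScheduler:
    """Stand-in scheduler; not the real diffusers class."""

    def __init__(self):
        self.timesteps = []

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        if timesteps is not None:
            self.timesteps = list(timesteps)
        else:
            # Evenly spaced schedule, purely for illustration.
            step = 1000 // num_inference_steps
            self.timesteps = list(range(999, -1, -step))[:num_inference_steps]


class FixedScheduleScheduler(BaseScheduler):
    """Always installs the fixed schedule, whatever the caller asks for."""

    def __init__(self, fixed_timesteps):
        super().__init__()
        self._fixed_timesteps = tuple(fixed_timesteps)

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        # Caller-supplied step counts and timesteps are deliberately ignored.
        super().set_timesteps(device=device, timesteps=self._fixed_timesteps)


scheduler = FixedScheduleScheduler([999, 750, 500, 250])  # placeholder values
scheduler.set_timesteps(num_inference_steps=50)
```

Because the override is part of the class, no method on the shared instance is swapped and restored per request, which removes the patch/restore race described above.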
3. NUM_INFERENCE_STEPS class constant is declared but never used
You define NUM_INFERENCE_STEPS = 4 on both classes but never reference it. The forward() signature defaults to num_inference_steps: int = 4 as a separate literal. Either use the constant in the default or remove it to avoid confusion.
4. _verify_dmd2_request mutates the request in-place silently
The method modifies req.sampling_params and req.prompts in-place. While there are log warnings, consider:
- Documenting in the method docstring that the request is mutated (not just verified -- the name _verify_* is misleading for a method that modifies data). A name like _sanitize_dmd2_request would be more accurate.
- Returning the modified request or a copy, so callers see that mutation happens. In-place mutation of shared request objects can cause subtle bugs if the same request is reused.
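A rough sketch of the copy-returning variant. The Request and SamplingParams dataclasses are hypothetical stand-ins for the real request types, and the default values are placeholders.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class SamplingParams:
    """Hypothetical stand-in for the real sampling-params type."""
    guidance_scale: float = 5.0
    num_inference_steps: int = 50


@dataclass
class Request:
    """Hypothetical stand-in for the real request type."""
    prompts: list = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)


def sanitize_dmd2_request(req, guidance_scale=1.0, num_inference_steps=4):
    """Return a sanitized copy; the caller's request object is not modified."""
    clean = copy.deepcopy(req)
    clean.sampling_params.guidance_scale = guidance_scale
    clean.sampling_params.num_inference_steps = num_inference_steps
    # Drop negative_prompt from dict prompts; leave other prompt types as-is.
    clean.prompts = [
        {k: v for k, v in p.items() if k != "negative_prompt"}
        if isinstance(p, dict) else p
        for p in clean.prompts
    ]
    return clean
```

Returning a new object makes the change visible at the call site (clean = sanitize_dmd2_request(req)), at the cost of one deep copy per request.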
5. p.get("negative_prompt") is falsy for empty strings
In _verify_dmd2_request:
if isinstance(p, dict) and p.get("negative_prompt"):
This will not strip negative_prompt if its value is "" (empty string). If the intent is to remove the key entirely when present, use "negative_prompt" in p instead.
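The difference between the two checks can be seen in isolation. The helper names below are made up for this illustration and are not part of the PR.

```python
def strip_negative_prompt_truthy(p):
    # Mirrors the current check: skips the key when its value is falsy.
    if isinstance(p, dict) and p.get("negative_prompt"):
        p.pop("negative_prompt")
    return p


def strip_negative_prompt_membership(p):
    # Suggested check: removes the key whenever it is present.
    if isinstance(p, dict) and "negative_prompt" in p:
        p.pop("negative_prompt")
    return p


kept = strip_negative_prompt_truthy({"prompt": "a cat", "negative_prompt": ""})
removed = strip_negative_prompt_membership({"prompt": "a cat", "negative_prompt": ""})
```

With an empty-string value, the truthiness check leaves the key in place, while the membership check removes it.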
6. Test helper uses object.__new__ to skip __init__ -- brittle
In test_wan_dmd2_request_sanitization.py, _make_pipeline does:
pipeline = object.__new__(cls)
torch.nn.Module.__init__(pipeline)
This creates a pipeline without running any __init__, so _verify_dmd2_request works only because it doesn't touch any instance state set in __init__. If the method ever accesses self.GUIDANCE_SCALE (which it already does), this works only because it's a class attribute. This is fragile -- consider using the mock-based approach from the scheduler test file instead, which is more robust.
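One possible mock-based variant is sketched below. It is not necessarily the approach used in the scheduler test file, which is not shown here; HeavyPipeline is a hypothetical stand-in whose __init__ is too expensive to run in a unit test.

```python
from unittest import mock


class HeavyPipeline:
    """Hypothetical pipeline whose real __init__ would load model weights."""

    GUIDANCE_SCALE = 1.0  # placeholder value

    def __init__(self):
        raise RuntimeError("would load model weights")

    def _verify_dmd2_request(self, params):
        params["guidance_scale"] = self.GUIDANCE_SCALE
        return params


def make_pipeline(cls):
    # Temporarily replace __init__ with a no-op so construction goes through
    # the normal cls() path; the original __init__ is restored on exit.
    with mock.patch.object(cls, "__init__", lambda self: None):
        return cls()


pipeline = make_pipeline(HeavyPipeline)
```

The patch is scoped to the with block, so the skipped initialization is explicit in the helper and the class is unchanged for other tests.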
7. Misplaced import in scheduler test
In test_wan_dmd2_scheduler.py, line 40:
from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams
This import is placed in the middle of the file (after fixtures, before test functions). Move it to the top with the other imports for consistency.
8. Unused import in scheduler test
import inspect on line 2 of test_wan_dmd2_scheduler.py is never used. Remove it.
9. Registry entries look correct
The additions to registry.py (pipeline registry, pre/post-process function mappings) look correct and consistent with existing entries. The DMD2 pipelines appropriately reuse the same pre/post-process functions as their parent pipelines.
10. PR description is empty
Please fill in the Purpose, Test Plan, and Test Result sections. Even for WIP, a brief description of the DMD2 distillation approach and which models this targets (model hub IDs) would help reviewers.
Overall, the core idea is sound. The main actionable items before this is ready for merge are: (1) extract the duplicated DMD2 logic into a mixin, (2) find a less fragile approach than monkey-patching set_timesteps, and (3) minor cleanup items above.
Signed-off-by: ayushag <ayushag@nvidia.com>
@lishunyang12 Can you review once again? |
|
@princepride PTAL |
|
There will be more DMD2 models based on different base models, so they will all essentially use DMD2PipelineMixin. Does it make sense to move these to a separate folder? |
Signed-off-by: ayushag <ayushag@nvidia.com>
let's wait until more models:) this deserves a new RFC |
Signed-off-by: ayushag <ayushag@nvidia.com>
|
can you run the benchmark diffusion and paste the results? |
We have more models in progress. I will write a lightweight RFC and update here. I am adding classes for other models as well. |
Signed-off-by: ayushag <ayushag@nvidia.com>
|
@hsliuustc0106 Thanks for merging this. I will follow up on this PR with required documentation and some benchmarks. |
…lm-project#2749) Signed-off-by: ayushag <ayushag@nvidia.com> Signed-off-by: nainiu258 <cperfect02@163.com>
…lm-project#2749) Signed-off-by: ayushag <ayushag@nvidia.com>
Purpose
Add support for FastGen DMD2-distilled 4-step models in vllm-omni. FastGen (NVlabs/FastGen) is NVIDIA's framework for training fast generative models via distillation, including DMD2.
DMD2-distilled models run in a small, fixed number of inference steps with guidance baked in (CFG not needed), substantially faster than the multi-step teacher while reusing the base model's text encoder / VAE / tokenizer.
Shared DMD2 module
- DMD2EulerScheduler (diffusion/models/schedulers/scheduling_dmd2_euler.py): subclass of FlowMatchEulerDiscreteScheduler that always returns the fixed DMD2 training timestep schedule, ignoring caller-passed num_inference_steps / sigmas / mu.
- DMD2PipelineMixin (diffusion/models/dmd2/mixin.py): reads dmd2_denoising_timesteps, dmd2_num_inference_steps, dmd2_guidance_scale, dmd2_scheduler_shift from model_index.json; sanitizes incoming requests to drop CFG / negative-prompt fields that don't apply.
- _load_json moved to diffusion/models/utils.py as a shared helper (previously duplicated inline across pipelines).
Model-specific stubs
Per-base-family 5-line classes composing the mixin with the existing base pipeline:
- WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline)
- WanI2VDMD2Pipeline(DMD2PipelineMixin, Wan22I2VPipeline)
- LTX2T2VDMD2Pipeline(DMD2PipelineMixin, LTX2Pipeline)
- LTX2I2VDMD2Pipeline(DMD2PipelineMixin, LTX2ImageToVideoPipeline)
All four are registered in diffusion/registry.py (pipeline + post-process) and exported from their respective package __init__.py.
Future DMD2 variants (additional Wan 2.x / LTX-2 / LTX-2.3 checkpoints) only need a new model_index.json; no code change is required. New base-pipeline families require one ~5-line stub + registry entries.
Test Plan
Plus a local end-to-end smoke test: vllm-omni serve <wan-dmd2-ckpt> --omni and a POST /v1/videos curl round-trip.
Test Result
Unit tests parametrized over all four DMD2 pipelines (Wan T2V/I2V, LTX-2 T2V/I2V) — all passing.
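As an illustration of the per-checkpoint configuration described under Shared DMD2 module, the sketch below reads the four dmd2_* keys from a model_index.json payload. The key names come from the description above; the read_dmd2_config helper, the choice to treat every key as required, and all values in the example payload are assumptions made for this sketch.

```python
import json

# Key names as listed in the PR description.
DMD2_KEYS = (
    "dmd2_denoising_timesteps",
    "dmd2_num_inference_steps",
    "dmd2_guidance_scale",
    "dmd2_scheduler_shift",
)


def read_dmd2_config(model_index_text):
    """Hypothetical helper: extract the DMD2 settings from model_index.json text."""
    config = json.loads(model_index_text)
    missing = [key for key in DMD2_KEYS if key not in config]
    if missing:
        # Assumption: all four keys are required for a DMD2 checkpoint.
        raise KeyError(f"model_index.json is missing DMD2 keys: {missing}")
    return {key: config[key] for key in DMD2_KEYS}


# Placeholder payload; real checkpoints define their own values.
EXAMPLE_MODEL_INDEX = json.dumps({
    "_class_name": "WanT2VDMD2Pipeline",
    "dmd2_denoising_timesteps": [999, 750, 500, 250],
    "dmd2_num_inference_steps": 4,
    "dmd2_guidance_scale": 1.0,
    "dmd2_scheduler_shift": 5.0,
})

dmd2_config = read_dmd2_config(EXAMPLE_MODEL_INDEX)
```

Keeping these settings in the checkpoint's own model_index.json is what lets a new DMD2 variant of an already-supported family ship without a code change.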